#!/usr/bin/env python
# coding: utf-8

# In[14]:


from datetime import datetime, timedelta
from netCDF4 import Dataset
import matplotlib
import matplotlib.pyplot as plt
import os
import numpy as np
import cmocean.cm as cmo
from matplotlib import ticker 
import glob 
import matplotlib as mpl
from wrf import getvar, interplevel, get_cartopy, vertcross
from matplotlib.colors import LogNorm
import sys
from string import ascii_lowercase

# Global matplotlib styling shared by every figure produced below.
# A single update() call applies the same keys, in the same order, through
# rcParams' validating __setitem__.
mpl.rcParams.update({
    'figure.figsize': [10, 10],
    'figure.titlesize': 10,
    'figure.titleweight': 'bold',
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
    'axes.labelsize': 11,
    'axes.titlesize': 10,
    'lines.linewidth': 1.8,
    'grid.linewidth': .25,
    'figure.subplot.wspace': 0.05,
    'figure.subplot.hspace': 0.05,
    'legend.fontsize': 11,
    'legend.framealpha': .75,
    'legend.loc': 'best',
    'savefig.bbox': 'tight',
    'savefig.dpi': 600,
})
#plt.rcParams['axes.facecolor'] = 'red'

def shiftedColorMap(cmap, start=0, midpoint=0.5, stop=1.0, name='shiftedcmap'):
    '''
    Function to offset the "center" of a colormap. Useful for
    data with a negative min and positive max and you want the
    middle of the colormap's dynamic range to be at zero.

    Input
    -----
      cmap : The matplotlib colormap to be altered
      start : Offset from lowest point in the colormap's range.
          Defaults to 0.0 (no lower offset). Should be between
          0.0 and `midpoint`.
      midpoint : The new center of the colormap. Defaults to
          0.5 (no shift). Should be between 0.0 and 1.0. In
          general, this should be  1 - vmax / (vmax + abs(vmin))
          For example if your data range from -15.0 to +5.0 and
          you want the center of the colormap at 0.0, `midpoint`
          should be set to  1 - 5/(5 + 15) or 0.75
      stop : Offset from highest point in the colormap's range.
          Defaults to 1.0 (no upper offset). Should be between
          `midpoint` and 1.0.
      name : Name given to the new colormap and used to register it
          with matplotlib. Defaults to 'shiftedcmap'.

    Returns
    -------
      newcmap : the shifted LinearSegmentedColormap. It is also registered
          under `name`; calling this again with the same name replaces the
          previously registered map.
    '''
    cdict = {'red': [], 'green': [], 'blue': [], 'alpha': []}

    # Evenly spaced sample points in the source colormap, from `start` to `stop`.
    reg_index = np.linspace(start, stop, 257)

    # Where each sample lands in the new map: the first 128 samples are squeezed
    # (or stretched) into [0, midpoint) and the remaining 129 into [midpoint, 1],
    # so the source map's center ends up at `midpoint`.
    shift_index = np.hstack([
        np.linspace(0.0, midpoint, 128, endpoint=False),
        np.linspace(midpoint, 1.0, 129, endpoint=True)
    ])

    for ri, si in zip(reg_index, shift_index):
        r, g, b, a = cmap(ri)

        cdict['red'].append((si, r, r))
        cdict['green'].append((si, g, g))
        cdict['blue'].append((si, b, b))
        cdict['alpha'].append((si, a, a))

    newcmap = matplotlib.colors.LinearSegmentedColormap(name, cdict)

    # pyplot.register_cmap was deprecated in Matplotlib 3.7 and removed in 3.9.
    # Use the matplotlib.colormaps registry (3.5+) when it exists. force=True
    # keeps the old behaviour of replacing an existing non-builtin map of the
    # same name, e.g. when this function is called more than once.
    registry = getattr(matplotlib, 'colormaps', None)
    if registry is not None and hasattr(registry, 'register'):
        registry.register(newcmap, force=True)
    else:
        plt.register_cmap(cmap=newcmap)

    return newcmap

# Output directories of the big2_maria_run9_add_spotting run, presumably
# split into meteorological (met), turbulence (turb) and fire fields.
# Only `rdir` is read by the analysis cell in this file.
rdir, rdir2, rdir3 = (
    '/glade/scratch/tjuliano/fire/marshall/joinwrfh/1min/big2_maria_run9_add_spotting/met/',
    '/glade/scratch/tjuliano/fire/marshall/joinwrfh/1min/big2_maria_run9_add_spotting/turb/',
    '/glade/scratch/tjuliano/fire/marshall/joinwrfh/1min/big2_maria_run9_add_spotting/fire/',
)
# Alternative run directory, kept for reference:
#rdir = '/glade/scratch/tjuliano/fire/sagehen/WRF-Fire-LEAPHI/WRF/test/marshall/2021123012/run/netcdf_2/'
# Horizontal grid spacing; presumably 111.111 m expressed in km (confirm).
dx = 111.111 / 1000


# ## Cross-section (xs) Froude-number analysis: PBL height, g', mean PBL wind speed and Fr (no wind vectors)

# In[ ]:



def _froude_column(th_col, z_col, u_col, v_col, sfc_hgt, grav=9.81, nlev=35):
    '''
    Bulk PBL / Froude-number diagnostics for a single model column.

    The PBL top is the lower level of the layer with the largest
    potential-temperature lapse (d theta / dz) among the lowest `nlev`
    layers. Level 0 is assumed to be the lowest model level (TODO confirm).

    Input
    -----
      th_col : potential temperature per level (K)
      z_col : height of each level (m, same reference as sfc_hgt)
      u_col, v_col : horizontal wind components per level (m/s)
      sfc_hgt : terrain height under the column (m)
      grav : gravitational acceleration (m s-2)
      nlev : number of layers searched for the strongest gradient

    Returns
    -------
      (pblh, mean_wspd, gprime, fr) : PBL depth above ground (m), mean wind
          speed below the PBL top (m/s), reduced gravity
          g' = g * (theta_ft - theta_ml) / theta_ml (m s-2), and the Froude
          number U / sqrt(g' * h).
          mean_wspd, gprime and fr are NaN (NumPy emits a RuntimeWarning)
          when the strongest gradient is in the lowest layer. fr is NaN
          when g' * h < 0.
    '''
    top = nlev + 1
    # Layer gradients between consecutive levels 0..nlev.
    dth_dz = np.diff(th_col[:top]) / np.diff(z_col[:top])
    k = np.argmax(dth_dz)
    pblh = z_col[k] - sfc_hgt
    mean_wspd = np.mean(np.sqrt(u_col[:k] ** 2. + v_col[:k] ** 2.))
    th_ml = np.mean(th_col[:k])
    # Theta two levels above the PBL top stands in for the free troposphere.
    th_ft = th_col[k + 2]
    gprime = grav * (th_ft - th_ml) / th_ml
    fr = mean_wspd / np.sqrt(gprime * pblh)
    return pblh, mean_wspd, gprime, fr


lowerc = list(ascii_lowercase)
# Output directory is created here; the figure below is written to the
# current working directory.
sdir = 'xs/'
os.makedirs(sdir, exist_ok=True)
loc = []
# Row offsets averaged over for each cross section.
offset = np.arange(-30, 11, 1)
g = 9.81
# Analysis times: 2021-12-<dd> <hh>:<mm> UTC, with matching legend labels.
dd = [30, 30, 30, 30, 31]
hh = [19, 19, 22, 23, 2]
mm = [0, 45, 5, 0, 30]
lab = ['1900 UTC', '1945 UTC', '2205 UTC', '2300 UTC', '0230 UTC']
n = 6
c = cmo.thermal(np.linspace(0, 1, n))
lw = 2.

for tidx in range(len(dd)):
    time = datetime(2021, 12, dd[tidx], hh[tidx], mm[tidx], 0)
    filename = 'wrfout_d02_%s' % time.strftime('%Y-%m-%d_%H:%M:%S')
    print(filename)
    with Dataset(rdir + filename) as nc:
        xlat = getvar(nc, 'lat').values
        xlon = getvar(nc, 'lon')
        u = getvar(nc, 'ua').values
        v = getvar(nc, 'va').values
        # 'T' plus a 300 K base gives potential temperature (presumably the
        # WRF perturbation-theta convention).
        th = getvar(nc, 'T').values + 300.
        ph = getvar(nc, 'PH').values
        phb = getvar(nc, 'PHB').values
        # Average total geopotential of adjacent levels, divided by g -> height (m).
        zht = ((ph[1:] + ph[:-1]) / 2. + (phb[1:] + phb[:-1]) / 2.) / g
        hgt = getvar(nc, 'HGT').values

    # Index of the maximum of `ua` at this time (kept for inspection).
    loc.append(np.unravel_index(np.argmax(u), u.shape))

    ncol = th.shape[-1]
    pblh = np.full((ncol, len(offset)), np.nan)
    mean_wspd = np.full_like(pblh, np.nan)
    gp = np.full_like(pblh, np.nan)
    fr = np.full_like(pblh, np.nan)

    for kk, off in enumerate(offset):
        # Cross section over the fire: row 262 + 15, shifted by `off`.
        x = 262 + 15 + off
        xlonv = xlon[x, :]
        zv = zht[:, x, :]
        uv = u[:, x, :]
        vv = v[:, x, :]
        thv = th[:, x, :]
        hgtv = hgt[x, :]
        for jj in range(ncol):
            (pblh[jj, kk], mean_wspd[jj, kk],
             gp[jj, kk], fr[jj, kk]) = _froude_column(
                thv[:, jj], zv[:, jj], uv[:, jj], vv[:, jj], hgtv[jj], grav=g)

    # Average across the row offsets, ignoring NaN columns.
    pblh_avg = np.nanmean(pblh, axis=1)
    gp_avg = np.nanmean(gp, axis=1)
    mean_wspd_avg = np.nanmean(mean_wspd, axis=1)
    fr_avg = np.nanmean(fr, axis=1)

    # First few longitudes where the averaged Froude number is <= 1.
    idxx = np.where(fr_avg <= 1.0)[0]
    print(xlonv[idxx[0:5]])

    if tidx == 0:
        # Build the 2x2 figure once; later times only add lines to it.
        fig, axes = plt.subplots(2, 2, figsize=(16, 8))
        ylims = [(0, 3600), (0.05, 0.42), (0, 50), (0, 8)]
        ylabels = ['PBL Height (m)', "g' (m s$^{-2}$)", 'V (m s$^{-1}$)', 'Froude Number']
        for ax, ylim, ylabel in zip(axes.flat, ylims, ylabels):
            # Explicit True: a bare grid() toggles, so repeated calls flip it.
            ax.grid(True)
            ax.set_ylim(*ylim)
            ax.set_ylabel(ylabel)
        for ax in axes[1]:
            ax.set_xlabel('Longitude')
        # Shade Fr >= 1 once. Redrawing this alpha=0.5 patch every time step
        # stacked the layers into near-opaque gray.
        axes[1, 1].fill_between([xlonv[0], xlonv[-1]], 1.0, 10., alpha=0.5, facecolor='gray')

    for ax, field in zip(axes.flat, (pblh_avg, gp_avg, mean_wspd_avg, fr_avg)):
        ax.set_xlim(xlonv[0], xlonv[-1])
        ax.plot(xlonv, field, c=c[tidx], lw=lw, label=lab[tidx])

axes[1, 1].legend()
fig.tight_layout()
# '%Y' rather than the glibc-only '%4Y'; identical output on Linux, portable elsewhere.
fig.savefig('xs_avg_fr_analysis_evolution_%s' % time.strftime('%Y-%m-%d_%H:%M:%S'))
plt.close(fig)

# ## For making the GIF File

# In[99]:


from PIL import Image

# Deliberate stop: the GIF assembly below is skipped unless this call is removed.
sys.exit()

# filepaths
fp_in = "./xs/*.png"
fp_out = "./xs/xs_xx.gif"

all_files = sorted(glob.glob(fp_in))
if not all_files:
    raise FileNotFoundError('no frames match %s' % fp_in)

frames = []
for path in all_files:
    # copy() loads the pixels into memory, so the file handle that
    # Image.open keeps can be closed right away.
    with Image.open(path) as frame:
        frames.append(frame.copy())

# https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html#gif
frames[0].save(fp=fp_out, format='GIF', append_images=frames[1:],
               save_all=True, duration=300, loop=0)


# In[ ]:




